#!/usr/bin/python
from __future__ import print_function
import sys
from sympy import Symbol,symbols,sin,cos,Rational,expand,simplify,collect
from galgebra.printer import enhance_print,Format
from galgebra.deprecated import MV
from galgebra.mv import Nga,ONE,ZERO
from galgebra.ga import Ga

def basic_multivector_operations():
    """Format basic multivector products in three successive algebras.

    ``MV.setup`` is called three times: for a three-vector basis and a
    two-vector basis with no metric argument, then for a two-vector basis
    with ``metric='[1,1]'``.  After each setup, symbolic multivectors are
    created and ``Fmt`` is invoked on them and on their products
    (``*`` geometric, ``^`` outer, ``|`` inner, ``<``/``>`` contractions).
    The metric of each algebra is printed as well.
    """
    # ---- three basis vectors, no metric argument ----
    (e3_x,e3_y,e3_z) = MV.setup('e*x|y|z')

    general = MV('A','mv')
    # Same multivector shown in each of the three Fmt modes, in order.
    for mode in (1,2,3):
        general.Fmt(mode,'A')

    x3 = MV('X','vector')
    y3 = MV('Y','vector')

    print('g_{ij} =\n',MV.metric)

    for vec,title in ((x3,'X'),(y3,'Y')):
        vec.Fmt(1,title)

    for title,product in (('X*Y',lambda: x3*y3),
                          ('X^Y',lambda: x3^y3),
                          ('X|Y',lambda: x3|y3)):
        product().Fmt(2,title)

    # ---- two basis vectors, no metric argument ----
    (e2_x,e2_y) = MV.setup('e*x|y')

    print('g_{ij} =\n',MV.metric)

    x2 = MV('X','vector')
    spinor2 = MV('A','spinor')

    x2.Fmt(1,'X')
    spinor2.Fmt(1,'A')

    for title,product in (('X|A',lambda: x2|spinor2),
                          ('X<A',lambda: x2<spinor2),
                          ('A>X',lambda: spinor2>x2)):
        product().Fmt(2,title)

    # ---- two basis vectors, metric '[1,1]' ----
    (eo_x,eo_y) = MV.setup('e*x|y',metric='[1,1]')

    print('g_{ii} =\n',MV.metric)

    xo = MV('X','vector')
    spinoro = MV('A','spinor')

    xo.Fmt(1,'X')
    spinoro.Fmt(1,'A')

    # Vector on the left first, then spinor on the left, four products each.
    for title,product in (('X*A',lambda: xo*spinoro),
                          ('X|A',lambda: xo|spinoro),
                          ('X<A',lambda: xo<spinoro),
                          ('X>A',lambda: xo>spinoro),
                          ('A*X',lambda: spinoro*xo),
                          ('A|X',lambda: spinoro|xo),
                          ('A<X',lambda: spinoro<xo),
                          ('A>X',lambda: spinoro>xo)):
        product().Fmt(2,title)

def check_generalized_BAC_CAB_formulas():
    """Print a list of product identities evaluated on five vectors.

    The vectors ``a``..``e`` come from ``MV.setup('a b c d e')`` (no metric
    argument).  The metric is printed first, then each labelled expression.
    """
    (a,b,c,d,e) = MV.setup('a b c d e')

    print('g_{ij} =\n',MV.metric)

    # Label/thunk pairs: each expression is evaluated only when its turn
    # to be printed comes, exactly as with one print statement per line.
    identities = (
        ('a|(b*c) =',lambda: a|(b*c)),
        ('a|(b^c) =',lambda: a|(b^c)),
        ('a|(b^c^d) =',lambda: a|(b^c^d)),
        ('a|(b^c)+c|(a^b)+b|(c^a) =',
         lambda: (a|(b^c))+(c|(a^b))+(b|(c^a))),
        ('a*(b^c)-b*(a^c)+c*(a^b) =',
         lambda: a*(b^c)-b*(a^c)+c*(a^b)),
        ('a*(b^c^d)-b*(a^c^d)+c*(a^b^d)-d*(a^b^c) =',
         lambda: a*(b^c^d)-b*(a^c^d)+c*(a^b^d)-d*(a^b^c)),
        ('(a^b)|(c^d) =',lambda: (a^b)|(c^d)),
        ('((a^b)|c)|d =',lambda: ((a^b)|c)|d),
        ('(a^b)x(c^d) =',lambda: Ga.com(a^b,c^d)),
        ('(a|(b^c))|(d^e) =',lambda: (a|(b^c))|(d^e)),
    )
    for label,evaluate in identities:
        print(label,evaluate())

def derivatives_in_rectangular_coordinates():
    """Print products of ``grad`` with multivector fields in x, y, z.

    The algebra is set up with metric ``'[1,1,1]'`` and coordinates
    ``(x, y, z)``; ``MV.setup`` then also returns ``grad``.  A scalar,
    vector, grade-2 and general multivector field (``fct=True``) are
    printed, followed by their products with ``grad``.
    """
    coords = symbols('x y z')
    (x,y,z) = coords
    (ex,ey,ez,grad) = MV.setup('e_x e_y e_z',metric='[1,1,1]',coords=coords)

    f = MV('f','scalar',fct=True)
    A = MV('A','vector',fct=True)
    B = MV('B','grade2',fct=True)
    C = MV('C','mv',fct=True)

    for label,field in (('f =',f),('A =',A),('B =',B),('C =',C)):
        print(label,field)

    # Thunks keep each derivative evaluated right before it is printed.
    derivatives = (
        ('grad*f =',lambda: grad*f),
        ('grad|A =',lambda: grad|A),
        ('grad*A =',lambda: grad*A),
        ('-I*(grad^A) =',lambda: -MV.I*(grad^A)),
        ('grad*B =',lambda: grad*B),
        ('grad^B =',lambda: grad^B),
        ('grad|B =',lambda: grad|B),
        ('grad<A =',lambda: grad<A),
        ('grad>A =',lambda: grad>A),
        ('grad<B =',lambda: grad<B),
        ('grad>B =',lambda: grad>B),
        ('grad<C =',lambda: grad<C),
        ('grad>C =',lambda: grad>C),
    )
    for label,evaluate in derivatives:
        print(label,evaluate())

def derivatives_in_spherical_coordinates():
    """Print products of ``grad`` with fields in coordinates r, theta, phi.

    ``MV.setup`` receives a ``curv`` argument made of two lists: the
    Cartesian position written in terms of (r, theta, phi), and the list
    ``[1, r, r*sin(theta)]``.  A scalar, vector and grade-2 field are
    printed, followed by four products with ``grad``.
    """
    coords = symbols('r theta phi')
    (r,th,phi) = coords

    position = [r*cos(phi)*sin(th),r*sin(phi)*sin(th),r*cos(th)]
    # NOTE(review): presumably the normalising scale factors of the
    # spherical frame - confirm against the galgebra setup documentation.
    scale = [1,r,r*sin(th)]
    curv = [position,scale]

    (er,eth,ephi,grad) = MV.setup('e_r e_theta e_phi',metric='[1,1,1]',coords=coords,curv=curv)

    f = MV('f','scalar',fct=True)
    A = MV('A','vector',fct=True)
    B = MV('B','grade2',fct=True)

    for label,field in (('f =',f),('A =',A),('B =',B)):
        print(label,field)

    derivatives = (
        ('grad*f =',lambda: grad*f),
        ('grad|A =',lambda: grad|A),
        ('-I*(grad^A) =',lambda: -MV.I*(grad^A)),
        ('grad^B =',lambda: grad^B),
    )
    for label,evaluate in derivatives:
        print(label,evaluate())

def rounding_numerical_components():
    """Print two float-valued vectors and their product, raw and via ``Nga``.

    ``Nga(mv, 2)`` is applied to the first vector and to the geometric
    product of both; each result is printed next to the unprocessed value.
    """
    (ex,ey,ez) = MV.setup('e_x e_y e_z',metric='[1,1,1]')

    first = 1.2*ex+2.34*ey+0.555*ez
    second = 0.333*ex+4*ey+5.3*ez

    print('X =',first)
    first_rounded = Nga(first,2)
    print('Nga(X,2) =',first_rounded)

    print('X*Y =',first*second)
    # The product is formed again here rather than reused.
    product_rounded = Nga(first*second,2)
    print('Nga(X*Y,2) =',product_rounded)

def noneuclidian_distance_calculation():
    """Work through a rotor calculation and solve for C = cosh(alpha).

    Steps, all of which print intermediate results:

    1. Set up an algebra on the vectors X, Y, e and build the line
       ``L = X^Y^e`` and the bivector ``B = L*e``.
    2. Form ``R = c + s*Bhat`` with ``Bhat = (1/B)*B`` and rotate X:
       ``Z = R*X*R.rev()``.
    3. Take the scalar ``W = Z|Y`` and, through a sequence of sympy
       substitutions (``s``, ``c`` -> ``S``, ``C``), reduce the requirement
       ``W = 0`` to a quadratic ``a*C**2 + b*C + c = 0``.
    4. Solve the quadratic and print the first root.

    The "D&L" equation numbers in the comments presumably refer to Doran &
    Lasenby, "Geometric Algebra for Physicists" - TODO confirm.

    The substitution steps depend on the exact form sympy gives to each
    intermediate expression, so their order must not be changed.
    """
    from sympy import solve,sqrt

    # Diagonal entries: X.X = 0, Y.Y = 0, e.e = 1.  The '#' entries are
    # not given numerically; presumably galgebra fills them with symbolic
    # dot products such as (X.Y) - confirm in the galgebra metric docs.
    metric = '0 # #,# 0 #,# # 1'
    (X,Y,e) = MV.setup('X Y e',metric)

    print('g_{ij} =',MV.metric)

    print('(X^Y)**2 =',(X^Y)*(X^Y))

    L = X^Y^e
    B = L*e # D&L 10.152
    print('B =',B)
    Bsq = B*B
    print('B**2 =',Bsq)
    # From here on Bsq is the scalar part only; it is used below as 1/Bsq.
    Bsq = Bsq.scalar()
    print('#L = X^Y^e is a non-euclidian line')
    print('B = L*e =',B)

    BeBr =B*e*B.rev()
    print('B*e*B.rev() =',BeBr)
    print('B**2 =',B*B)
    print('L**2 =',L*L) # D&L 10.153
    # Plain sympy symbols.  Binv is named '(1/B)'; XdotY, Xdote, Ydote are
    # named '(X.Y)', '(X.e)', '(Y.e)' - presumably so that they coincide
    # with the symbols galgebra uses for the '#' metric entries (verify).
    # alpha is not used as a variable below, and M is rebound before use.
    (s,c,Binv,M,S,C,alpha,XdotY,Xdote,Ydote) = symbols('s c (1/B) M S C alpha (X.Y) (X.e) (Y.e)')

    Bhat = Binv*B # D&L 10.154
    R = c+s*Bhat # Rotor R = exp(alpha*Bhat/2)
    print('s = sinh(alpha/2) and c = cosh(alpha/2)')
    print('exp(alpha*B/(2*|B|)) =',R)

    Z = R*X*R.rev() # D&L 10.155
    # Tidy the underlying sympy expression of Z in place before display.
    Z.obj = expand(Z.obj)
    Z.obj = Z.obj.collect([Binv,s,c,XdotY])
    Z.Fmt(3,'R*X*R.rev()')
    W = Z|Y # Extract scalar part of multivector
    # From this point forward all calculations are with sympy scalars
    print('Objective is to determine value of C = cosh(alpha) such that W = 0')
    W = W.scalar()
    print('Z|Y =',W)
    W = expand(W)
    W = simplify(W)
    W = W.collect([s*Binv])

    # M stops being a bare symbol here: it becomes 1/B**2 computed from
    # the metric, and replaces every Binv**2 in W.
    M = 1/Bsq
    W = W.subs(Binv**2,M)
    W = simplify(W)
    # Expression substituted for 1/Binv further down and printed as |B|.
    Bmag = sqrt(XdotY**2-2*XdotY*Xdote*Ydote)
    W = W.collect([Binv*c*s,XdotY])

    # Double-angle substitutions: rewrite the half-angle symbols s, c in
    # terms of S and C (2*c*s -> S, c**2 -> (C+1)/2, s**2 -> (C-1)/2).

    W = W.subs(2*XdotY**2-4*XdotY*Xdote*Ydote,2/(Binv**2))
    W = W.subs(2*c*s,S)
    W = W.subs(c**2,(C+1)/2)
    W = W.subs(s**2,(C-1)/2)
    W = simplify(W)
    W = W.subs(1/Binv,Bmag)
    W = expand(W)

    print('S = sinh(alpha) and C = cosh(alpha)')

    print('W =',W)

    # With evaluate=False the result is indexed by term below: C, S, and
    # ONE for the part containing neither.
    Wd = collect(W,[C,S],exact=True,evaluate=False)

    Wd_1 = Wd[ONE]
    Wd_C = Wd[C]
    Wd_S = Wd[S]

    print('Scalar Coefficient =',Wd_1)
    print('Cosh Coefficient =',Wd_C)
    print('Sinh Coefficient =',Wd_S)

    print('|B| =',Bmag)
    # Switch back from the explicit square root to the 1/Binv shorthand.
    Wd_1 = Wd_1.subs(Bmag,1/Binv)
    Wd_C = Wd_C.subs(Bmag,1/Binv)
    Wd_S = Wd_S.subs(Bmag,1/Binv)

    # W = 0 rearranged as (Wd_1 + Wd_C*C) = -Wd_S*S; squaring both sides
    # lets S**2 be replaced by C**2-1, leaving a polynomial in C only.
    lhs = Wd_1+Wd_C*C
    rhs = -Wd_S*S
    lhs = lhs**2
    rhs = rhs**2
    W = expand(lhs-rhs)
    W = expand(W.subs(1/Binv**2,Bmag**2))
    W = expand(W.subs(S**2,C**2-1))
    W = W.collect([C,C**2],evaluate=False)

    # Quadratic coefficients.  Note: the last line rebinds c, which until
    # now was the symbol printed as cosh(alpha/2).
    a = simplify(W[C**2])
    b = simplify(W[C])
    c = simplify(W[ONE])

    print('Require a*C**2+b*C+c = 0')

    print('a =',a)
    print('b =',b)
    print('c =',c)

    x = Symbol('x')
    # C is rebound from a symbol to the first root returned by solve().
    # NOTE(review): the printed label states C = -b/(2*a), i.e. it assumes
    # a repeated root; that is not checked here.
    C =  solve(a*x**2+b*x+c,x)[0]
    print('cosh(alpha) = C = -b/(2*a) =',expand(simplify(expand(C))))
    return

# Exact rational one-half (sympy Rational); the prefactor used in F().
HALF = Rational(1,2)

def F(x):
    """Return ``1/2*(x*x*n + 2*x - nbar)`` for the argument ``x``.

    ``n`` and ``nbar`` are read from module globals, so one of the setup
    functions that assigns them must have run before this is called.
    """
    quadratic_term = (x*x)*n
    return HALF*(quadratic_term+2*x-nbar)

def make_vector(a, n=3):
    """Return ``F(a)``, first building a symbolic vector if ``a`` is a string.

    For a string ``a`` the symbols ``a1 .. an`` are created, two ``ZERO``
    components are appended, and the resulting list is turned into an
    ``MV`` vector before being passed to ``F``.  Any other ``a`` is passed
    to ``F`` unchanged.

    Note that the parameter ``n`` is a component count and is unrelated to
    the module-level ``n`` that ``F`` reads.
    """
    if not isinstance(a,str):
        return F(a)

    # Every name is followed by a space, including the last one.
    names = ''.join(a+str(k)+' ' for k in range(1,n+1))
    components = list(symbols(names))
    components.extend([ZERO,ZERO])
    return F(MV(components,'vector'))

def conformal_representations_of_circles_lines_spheres_and_planes():
    """Print outer-product equations for a circle, line, sphere and plane.

    Sets up a five-vector algebra ``e_1 e_2 e_3 n nbar`` in which the
    ``e_i`` are orthonormal, ``n.n = nbar.nbar = 0`` and ``n.nbar = 2``,
    and stores ``n`` and ``nbar`` in module globals for ``F`` to read.
    Four fixed points and a symbolic point ``x`` are mapped through
    ``make_vector``; the wedge products that vanish when ``x`` lies on
    each object are then printed.
    """
    global n,nbar

    metric = ','.join(['1 0 0 0 0',
                       '0 1 0 0 0',
                       '0 0 1 0 0',
                       '0 0 0 0 2',
                       '0 0 0 2 0'])

    (e1,e2,e3,n,nbar) = MV.setup('e_1 e_2 e_3 n nbar',metric)

    print('g_{ij} =\n',MV.metric)

    e = n+nbar

    # Conformal representation of points.
    A = make_vector(e1)    # point a = (1,0,0)  A = F(a)
    B = make_vector(e2)    # point b = (0,1,0)  B = F(b)
    C = make_vector(-e1)   # point c = (-1,0,0) C = F(c)
    D = make_vector(e3)    # point d = (0,0,1)  D = F(d)
    X = make_vector('x',3)

    for label,point in (('F(a) =',A),('F(b) =',B),('F(c) =',C),
                        ('F(d) =',D),('F(x) =',X)):
        print(label,point)

    print('a = e1, b = e2, c = -e1, and d = e3')
    print('A = F(a) = 1/2*(a*a*n+2*a-nbar), etc.')

    # (heading, label, thunk): each wedge product is evaluated only when
    # it is about to be printed.
    objects = (
        ('Circle through a, b, and c',
         'Circle: A^B^C^X = 0 =',lambda: (A^B^C^X)),
        ('Line through a and b',
         'Line  : A^B^n^X = 0 =',lambda: (A^B^n^X)),
        ('Sphere through a, b, c, and d',
         'Sphere: A^B^C^D^X = 0 =',lambda: (((A^B)^C)^D)^X),
        ('Plane through a, b, and d',
         'Plane : A^B^n^D^X = 0 =',lambda: (A^B^n^D^X)),
    )
    for heading,label,evaluate in objects:
        print(heading)
        print(label,evaluate())

    hyperbolic = (A^B^e)^X
    hyperbolic.Fmt(3,'Hyperbolic Circle: (A^B^e)^X = 0 =')

def properties_of_geometric_objects():
    """Print the direction of a line and the plane of a circle.

    The algebra has basis ``p1 p2 p3 n nbar``; the 3x3 block for the
    ``p`` vectors is given as ``#`` entries, the ``p`` vectors have zero
    entries against ``n`` and ``nbar``, and ``n.n = nbar.nbar = 0``,
    ``n.nbar = 2``.  ``n`` and ``nbar`` are stored in module globals for
    ``F`` to read.
    """
    global n,nbar

    rows = ['# # # 0 0']*3+['0 0 0 0 2','0 0 0 2 0']
    metric = ','.join(rows)

    (p1,p2,p3,n,nbar) = MV.setup('p1 p2 p3 n nbar',metric)

    print('g_{ij} =\n',MV.metric)

    P1,P2,P3 = F(p1),F(p2),F(p3)

    print('Extracting direction of line from L = P1^P2^n')
    line = P1^P2^n
    line_direction = (line|n)|nbar
    print('(L|n)|nbar =',line_direction)

    print('Extracting plane of circle from C = P1^P2^P3')
    circle = P1^P2^P3
    circle_plane = ((circle^n)|n)|nbar
    print('((C^n)|n)|nbar =',circle_plane)
    # Printed next to the line above for comparison.
    print('(p2-p1)^(p3-p1) =',(p2-p1)^(p3-p1))

def extracting_vectors_from_conformal_2_blade():
    """Print vectors derived from the 2-blade ``B = P1^P2`` and a vector ``a``.

    The metric has ``P1.P1 = P2.P2 = 0`` and ``P1.P2 = -1``; every entry
    involving ``a`` is given as ``#``.  The function prints ``B**2``,
    ``a' = a-(a^B)*B``, the combinations ``a'+a'*B`` and ``a'-a'*B`` with
    their squares, and finally ``a|B``.
    """
    metric = ','.join([' 0 -1 #',
                       '-1 0 #',
                       ' # # #'])

    (P1,P2,a) = MV.setup('P1 P2 a',metric)

    print('g_{ij} =\n',MV.metric)

    blade = P1^P2
    blade_sq = blade*blade
    print('B**2 =',blade_sq)

    a_prime = a-(a^blade)*blade
    print("a' = a-(a^B)*B =",a_prime)

    a_plus = a_prime+a_prime*blade
    a_minus = a_prime-a_prime*blade

    print("A+ = a'+a'*B =",a_plus)
    print("A- = a'-a'*B =",a_minus)

    print('(A+)^2 =',a_plus*a_plus)
    print('(A-)^2 =',a_minus*a_minus)

    print('a|B =',a|blade)

def reciprocal_frame_test():
    """Print the inner products between a frame and vectors built from it.

    The basis ``e1 e2 e3`` has unit diagonal metric entries and ``#``
    off-diagonal entries.  With ``E = e1^e2^e3`` the vectors

        E1 = (e2^e3)*E,  E2 = -(e1^e3)*E,  E3 = (e1^e2)*E

    are formed.  The function prints ``Ei|ej`` for every ``i != j`` and
    then ``(Ei|ei)/E**2`` for each ``i``.

    Printed labels and their order are exactly those of the earlier
    hand-unrolled version; the nine repeated stanzas are now two loops and
    an unused ``1/E**2`` temporary has been removed.
    """
    metric = ','.join(['1 # #',
                       '# 1 #',
                       '# # 1'])

    (e1,e2,e3) = MV.setup('e1 e2 e3',metric)

    print('g_{ij} =\n',MV.metric)

    E = e1^e2^e3
    Esq = (E*E).scalar()
    print('E =',E)
    print('E**2 =',Esq)

    E1 = (e2^e3)*E
    E2 = (-1)*(e1^e3)*E
    E3 = (e1^e2)*E

    print('E1 = (e2^e3)*E =',E1)
    print('E2 =-(e1^e3)*E =',E2)
    print('E3 = (e1^e2)*E =',E3)

    basis = (e1,e2,e3)
    derived = (E1,E2,E3)

    # Off-diagonal products, row by row: E1|e2, E1|e3, E2|e1, E2|e3, ...
    for i,Ei in enumerate(derived,start=1):
        for j,ej in enumerate(basis,start=1):
            if i != j:
                print('E%d|e%d =' % (i,j),(Ei|ej).expand())

    # Diagonal products, each divided by the expanded E**2.
    Esq = expand(Esq)
    for i,(Ei,ei) in enumerate(zip(derived,basis),start=1):
        w = ((Ei|ei).expand()).scalar()
        print('(E%d|e%d)/E**2 =' % (i,i),simplify(w/Esq))
    return

def main():
    """Call ``enhance_print()`` and then run the demo functions in order."""
    enhance_print()

    demos = (
        basic_multivector_operations,
        check_generalized_BAC_CAB_formulas,
        derivatives_in_rectangular_coordinates,
        derivatives_in_spherical_coordinates,
        rounding_numerical_components,
        # noneuclidian_distance_calculation is not run (disabled here).
        conformal_representations_of_circles_lines_spheres_and_planes,
        properties_of_geometric_objects,
        extracting_vectors_from_conformal_2_blade,
        reciprocal_frame_test,
    )
    for demo in demos:
        demo()

# Run the demos only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()